// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package inspector_test

import (
	"go/ast"
	"go/build"
	"go/parser"
	"go/token"
	"log"
	"path/filepath"
	"reflect"
	"strconv"
	"strings"
	"testing"

	"golang.org/x/tools/go/ast/inspector"
)

// netFiles holds the parsed syntax trees of the "net" package. It is the
// shared input for the tests and benchmarks in this file.
var netFiles []*ast.File

// init fills in netFiles before any test or benchmark runs. If the package
// cannot be loaded or parsed, the process is terminated via log.Fatal.
func init() {
	var err error
	if netFiles, err = parseNetFiles(); err != nil {
		log.Fatal(err)
	}
}

// parseNetFiles locates the "net" package through the default build
// context and parses each of its GoFiles into a syntax tree, using a single
// token.FileSet for all of them. On the first import or parse error it
// returns a nil slice together with that error.
func parseNetFiles() ([]*ast.File, error) {
	netPkg, err := build.Default.Import("net", "", 0)
	if err != nil {
		return nil, err
	}

	var (
		fset   = token.NewFileSet()
		parsed []*ast.File
	)
	for _, name := range netPkg.GoFiles {
		// GoFiles holds bare file names; resolve them against the package directory.
		path := filepath.Join(netPkg.Dir, name)
		file, err := parser.ParseFile(fset, path, nil, 0)
		if err != nil {
			return nil, err
		}
		parsed = append(parsed, file)
	}
	return parsed, nil
}

// TestInspectAllNodes compares Inspector against ast.Inspect.
func TestInspectAllNodes(t *testing.T) {
	inspect := inspector.New(netFiles)

	var nodesA []ast.Node
	inspect.Nodes(nil, func(n ast.Node, push bool) bool {
		if push {
			nodesA = append(nodesA, n)
		}
		return true
	})
	var nodesB []ast.Node
	for _, f := range netFiles {
		ast.Inspect(f, func(n ast.Node) bool {
			if n != nil {
				nodesB = append(nodesB, n)
			}
			return true
		})
	}
	compare(t, nodesA, nodesB)
}

// TestInspectGenericNodes checks that the Inspector reaches every identifier
// in a source file that uses generic syntax (constraint unions, type
// parameter lists, and multi-argument instantiations), and that filtering
// on *ast.IndexListExpr reports every such node seen by an unfiltered
// traversal.
func TestInspectGenericNodes(t *testing.T) {
	// src is using the 16 identifiers i0, i1, ... i15 so
	// we can easily verify that we've found all of them.
	const src = `package a

type I interface { ~i0|i1 }

type T[i2, i3 interface{ ~i4 }] struct {}

func f[i5, i6 any]() {
	_ = f[i7, i8]
	var x T[i9, i10]
}

func (*T[i11, i12]) m()

var _ i13[i14, i15]
`
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "a.go", src, 0)
	if err != nil {
		// A parse failure would otherwise surface only as confusing
		// "missed identifier" errors on a partial syntax tree.
		t.Fatal(err)
	}
	inspect := inspector.New([]*ast.File{f})
	found := make([]bool, 16)

	// Each key is an IndexListExpr seen by the unfiltered traversal; the
	// value records whether the filtered traversal later reported it too.
	indexListExprs := make(map[*ast.IndexListExpr]bool)

	// Verify that we reach all i* identifiers, and collect IndexListExpr nodes.
	inspect.Preorder(nil, func(n ast.Node) {
		switch n := n.(type) {
		case *ast.Ident:
			if n.Name[0] == 'i' {
				index, err := strconv.Atoi(n.Name[1:])
				if err != nil {
					t.Fatal(err)
				}
				found[index] = true
			}
		case *ast.IndexListExpr:
			indexListExprs[n] = false
		}
	})
	for i, v := range found {
		if !v {
			t.Errorf("missed identifier i%d", i)
		}
	}

	// Verify that we can filter to IndexListExprs that we found in the first
	// step.
	if len(indexListExprs) == 0 {
		t.Fatal("no index list exprs found")
	}
	inspect.Preorder([]ast.Node{&ast.IndexListExpr{}}, func(n ast.Node) {
		ix := n.(*ast.IndexListExpr)
		indexListExprs[ix] = true
	})
	for ix, v := range indexListExprs {
		if !v {
			t.Errorf("inspected node %v not filtered", ix)
		}
	}
}

// TestInspectPruning checks that Inspector and ast.Inspect report the same
// node sequence when neither traversal descends into ast.CallExpr nodes.
func TestInspectPruning(t *testing.T) {
	// isCall reports whether n is a function call expression.
	isCall := func(n ast.Node) bool {
		_, ok := n.(*ast.CallExpr)
		return ok
	}

	var viaInspector []ast.Node
	inspector.New(netFiles).Nodes(nil, func(n ast.Node, push bool) bool {
		if !push {
			return false
		}
		viaInspector = append(viaInspector, n)
		return !isCall(n) // don't descend into function calls
	})

	var viaAST []ast.Node
	for _, file := range netFiles {
		ast.Inspect(file, func(n ast.Node) bool {
			if n == nil {
				return false
			}
			viaAST = append(viaAST, n)
			return !isCall(n) // don't descend into function calls
		})
	}

	compare(t, viaInspector, viaAST)
}

// compare calls t.Error if !slices.Equal(nodesA, nodesB).
//
// When the lengths differ, a single error reporting both lengths is
// emitted. Otherwise one error is emitted for each index at which the
// elements differ. compare is marked as a test helper so that failures are
// attributed to the calling test's line.
func compare[N comparable](t *testing.T, nodesA, nodesB []N) {
	t.Helper()
	if len(nodesA) != len(nodesB) {
		t.Errorf("inconsistent node lists: %d vs %d", len(nodesA), len(nodesB))
		return
	}
	for i, a := range nodesA {
		if b := nodesB[i]; a != b {
			t.Errorf("node %d is inconsistent: %T, %T", i, a, b)
		}
	}
}

// TestTypeFiltering checks, for a small function with two call statements,
// the node sequence reported by Nodes without a filter, the sequence
// reported when restricted to BasicLit and CallExpr, and the stacks passed
// to the WithStack callback under the same restriction.
func TestTypeFiltering(t *testing.T) {
	const src = `package a
func f() {
	print("hi")
	panic("oops")
}
`
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "a.go", src, 0)
	if err != nil {
		// The expected sequences below assume a complete syntax tree.
		t.Fatal(err)
	}
	inspect := inspector.New([]*ast.File{f})

	// fn records the type name of each node on its push event.
	var got []string
	fn := func(n ast.Node, push bool) bool {
		if push {
			got = append(got, typeOf(n))
		}
		return true
	}

	// no type filtering
	inspect.Nodes(nil, fn)
	if want := strings.Fields("File Ident FuncDecl Ident FuncType FieldList BlockStmt ExprStmt CallExpr Ident BasicLit ExprStmt CallExpr Ident BasicLit"); !reflect.DeepEqual(got, want) {
		t.Errorf("inspect: got %s, want %s", got, want)
	}

	// type filtering
	nodeTypes := []ast.Node{
		(*ast.BasicLit)(nil),
		(*ast.CallExpr)(nil),
	}
	got = nil
	inspect.Nodes(nodeTypes, fn)
	if want := strings.Fields("CallExpr BasicLit CallExpr BasicLit"); !reflect.DeepEqual(got, want) {
		t.Errorf("inspect: got %s, want %s", got, want)
	}

	// inspect with stack
	got = nil
	inspect.WithStack(nodeTypes, func(n ast.Node, push bool, stack []ast.Node) bool {
		if push {
			// Render the stack passed to the callback as a space-separated
			// list of type names.
			var line []string
			for _, elem := range stack {
				line = append(line, typeOf(elem))
			}
			got = append(got, strings.Join(line, " "))
		}
		return true
	})
	want := []string{
		"File FuncDecl BlockStmt ExprStmt CallExpr",
		"File FuncDecl BlockStmt ExprStmt CallExpr BasicLit",
		"File FuncDecl BlockStmt ExprStmt CallExpr",
		"File FuncDecl BlockStmt ExprStmt CallExpr BasicLit",
	}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("inspect: got %s, want %s", got, want)
	}
}

// typeOf returns the name of n's dynamic type as reported by reflection,
// with a leading "*ast." removed (for example "CallExpr" for an
// *ast.CallExpr).
func typeOf(n ast.Node) string {
	full := reflect.TypeOf(n).String()
	return strings.TrimPrefix(full, "*ast.")
}

// With the numbers below, the marginal improvement (ASTInspect/Inspect)
// is about 2.6x, and the break-even point
// (NewInspector/(ASTInspect-Inspect)) is about 3.6, i.e. roughly 4
// traversals.
//
// BenchmarkASTInspect     1.0 ms
// BenchmarkNewInspector   2.2 ms
// BenchmarkInspect        0.39ms
// BenchmarkInspectFilter  0.01ms
// BenchmarkInspectCalls   0.14ms

// BenchmarkNewInspector measures the one-time cost of constructing an
// Inspector over the parsed net package files.
func BenchmarkNewInspector(b *testing.B) {
	for iter := 0; iter < b.N; iter++ {
		_ = inspector.New(netFiles)
	}
}

// BenchmarkInspect measures the marginal cost of one unfiltered Preorder
// traversal that counts function declarations and function literals.
func BenchmarkInspect(b *testing.B) {
	// Build the inspector with the timer paused so that construction is
	// not included in the measurement.
	b.StopTimer()
	in := inspector.New(netFiles)
	b.StartTimer()

	var (
		funcDecls int
		funcLits  int
	)
	for iter := 0; iter < b.N; iter++ {
		in.Preorder(nil, func(node ast.Node) {
			switch node.(type) {
			case *ast.FuncDecl:
				funcDecls++
			case *ast.FuncLit:
				funcLits++
			}
		})
	}
}

// BenchmarkInspectFilter measures the marginal cost of one Preorder
// traversal restricted to FuncDecl and FuncLit nodes.
func BenchmarkInspectFilter(b *testing.B) {
	// Build the inspector with the timer paused so that construction is
	// not included in the measurement.
	b.StopTimer()
	in := inspector.New(netFiles)
	b.StartTimer()

	funcsOnly := []ast.Node{
		(*ast.FuncDecl)(nil),
		(*ast.FuncLit)(nil),
	}
	var (
		funcDecls int
		funcLits  int
	)
	for iter := 0; iter < b.N; iter++ {
		in.Preorder(funcsOnly, func(node ast.Node) {
			switch node.(type) {
			case *ast.FuncDecl:
				funcDecls++
			case *ast.FuncLit:
				funcLits++
			}
		})
	}
}

// BenchmarkInspectCalls measures the marginal cost of one Preorder
// traversal restricted to CallExpr nodes.
func BenchmarkInspectCalls(b *testing.B) {
	// Build the inspector with the timer paused so that construction is
	// not included in the measurement.
	b.StopTimer()
	in := inspector.New(netFiles)
	b.StartTimer()

	callsOnly := []ast.Node{(*ast.CallExpr)(nil)}
	var callCount int
	for iter := 0; iter < b.N; iter++ {
		in.Preorder(callsOnly, func(node ast.Node) {
			// The assertion panics if the filter lets any other node type through.
			_ = node.(*ast.CallExpr)
			callCount++
		})
	}
}

// BenchmarkASTInspect measures a full ast.Inspect traversal of every net
// package file, counting function declarations and function literals. It
// is the baseline the Inspector benchmarks are compared against.
func BenchmarkASTInspect(b *testing.B) {
	var (
		funcDecls int
		funcLits  int
	)
	for iter := 0; iter < b.N; iter++ {
		for _, file := range netFiles {
			ast.Inspect(file, func(node ast.Node) bool {
				switch node.(type) {
				case *ast.FuncDecl:
					funcDecls++
				case *ast.FuncLit:
					funcLits++
				}
				return true
			})
		}
	}
}
